#pragma once

#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/maxinfo_propagator.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class SchedulerRuntimeInfo;
class ExpressionEvaluator;
class HeuristicSummary;

namespace scheduler_utils {

// Assume any only half of the register file is available to spend on buffers,
// this is because when we allocate a buffer in register is has to be accesed
// with a compile time coonstant index. Unfortunately nvcc seems to be using
// many registers for indexing. This is a bad estimation of extra register use,
// but it's hard to get a better one.
constexpr int64_t register_file_size = 256 * 1024 / 2;
constexpr int64_t x_grid_limit = ((int64_t)1 << (int64_t)31) - (int64_t)1;
constexpr int64_t y_grid_limit = 65535;
constexpr int64_t z_grid_limit = 65535;
constexpr int64_t z_block_limit = 64;

// Largest Power of 2 less-than n
constexpr int64_t lastPow2(int64_t n) {
  TORCH_INTERNAL_ASSERT(n >= 0);
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
  n |= (n >> 16); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
  n |= (n >> 32); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
  return std::max((int64_t)1, n - (n >> 1));
}

// Div x by y, but min at 1
inline int64_t safeDiv(const int64_t x, const int64_t y) {
  return std::max(x / y, (int64_t)1);
}

// Split the given dimensions in `to_split`. Also update the dimensions in
// `to_update` to the positions in the splitted tensor. Splitting one dimension
// multiple times is supported, and if this is the case, then the order of
// `to_split` matters. All given dimensions are numbers before any split.
TORCH_CUDA_CU_API void splitDims(
    TensorView* tv,
    std::vector<std::pair<size_t, size_t>> to_split, // (dim, size)
    std::vector<size_t>& to_update);

TORCH_CUDA_CU_API inline void splitDims(
    TensorView* tv,
    std::vector<std::pair<size_t, size_t>> to_split) { // (dim, size)
  std::vector<size_t> unused;
  splitDims(tv, std::move(to_split), unused);
}

// Merge all the given dimensions in `to_merge` into a single dimension. Also
// update the dimensions in `to_update` to the positions in the merged tensor.
// Returns the merged dimension. All given dimensions are numbers before any
// merge.
TORCH_CUDA_CU_API c10::optional<size_t> mergeDims(
    TensorView* tv,
    std::vector<size_t> to_merge,
    std::vector<size_t>& to_update);

TORCH_CUDA_CU_API inline c10::optional<size_t> mergeDims(
    TensorView* tv,
    std::vector<size_t> to_merge) {
  std::vector<size_t> unused;
  return mergeDims(tv, std::move(to_merge), unused);
}

// Merge all reduction to the right side and returns total number of
// reduction axes. Don't merge is typically used for trivial reductions.
size_t mergeReduction(
    TensorView* tv,
    const std::unordered_set<IterDomain*>& dont_merge = {});

// merge all non-reduction axes to the left side and returns total number of
// iteration axes. Don't merge is typically used for trivial reductions.
size_t mergeNonReduction(
    TensorView* tv,
    const std::unordered_set<IterDomain*>& dont_merge = {});

// Propagate the parallelization from the selected dimensions of the reference
// tensor to their corresponding dimensions in all selected tensors in the DAG.
// Position `pos` means selecting all the dimensions [0, 1, ..., pos - 1]. pos =
// -1 means selecting all dimensions. `selected_tvs` are selected tensors in the
// DAG. Empty `selected_tvs` means selecting all tensors in the fusion of
// `reference_tv`. `selected_parallel_types` are the selected parallel types.
// Empty `selected_parallel_types` means selecting all parallel types.
TORCH_CUDA_CU_API void parallelizeAllLike(
    TensorView* reference_tv,
    int64_t pos = -1,
    std::vector<TensorView*> selected_tvs = {},
    const std::unordered_set<ParallelType>& selected_parallel_types = {},
    bool propagate_padding = true);

TORCH_CUDA_CU_API inline void parallelizeAllLike(
    TensorView* reference_tv,
    std::vector<TensorView*> selected_tvs,
    const std::unordered_set<ParallelType>& selected_parallel_types = {},
    bool propagate_padding = true) {
  parallelizeAllLike(
      reference_tv,
      -1,
      std::move(selected_tvs),
      selected_parallel_types,
      propagate_padding);
}

TORCH_CUDA_CU_API void computeAtInputs(
    TensorView* consumer,
    int pos,
    ComputeAtMode mode = ComputeAtMode::Standard);

TORCH_CUDA_CU_API void computeWithOutputs(
    TensorView* producer,
    int pos,
    ComputeAtMode mode = ComputeAtMode::Standard);

struct PersistentBufferInfo {
  std::vector<TensorView*> persistent_buffers;
  std::unordered_set<IterDomain*> unmappable_dims;

  // Persistent buffers are needed until the path through the reduction -
  // broadcast chain is resolved by any other chain using the persistent buffer
  // that is not going through a reduction. This assumes all reduction paths
  // have the same reduction pattern. Order is the same as persistent_buffers
  std::vector<std::vector<TensorView*>> persistent_buffer_resolution_points;

  // Not all persistent buffers can be projected to inputs, if a buffer can be
  // projected to the inputs which may reduce the persistent buffer size (BN
  // Backwards specifically) then keep track of it here. Persistent buffers that
  // have a persistent buffer/reduction before them should not be projected
  // through that.
  std::vector<TensorView*> projectable_persistent_buffers;

  // Track inputs of input projectable buffers
  std::vector<TensorView*> projectable_buffer_inputs;

  // Map unmappable dims to projectable_buffer_inputs
  std::unordered_set<IterDomain*> unamppable_dims_projected_to_inputs;
};

// Buffers whos roots can't map to all producer roots based on compute at. These
// are the buffers we would make persistent in a persistent kerenl or would have
// to recompute if we can't make a persistent kernel. This function will also
// return inputs as being marked persistent if they follow this pattern. It is
// important to note however inputs don't strictly have to be persistent as they
// can simply be read multiple times from GMEM in the same kernel.
TORCH_CUDA_CU_API PersistentBufferInfo persistentBuffers(Fusion* fusion);

struct TvProperties {
  // How many elements in tensor view are there to reduce.
  int64_t total_reduction_numel = 1;

  // How many reductions do we need to perform, i.e. how many iter dimension.
  // elements are there
  int64_t total_iteration_numel = 1;

  // Is the inner most dimension a reduction, if no reductions mark true.
  bool fastest_dim_reduction = true;

  // How many elements in the inner most dimension merging surrounding domains
  // that match in type. This is used for 3D schedulers in
  // reduction/normalization.
  int64_t inner_most_dimension_numel = 1;

  // Same thing as above, but the number of dimensions instead of the numel.
  int64_t inner_most_dimension_ndims = 1;

  // Merging neighboring iteration domains, and reduction domains, what's the
  // resulting dimensionality of the problem.
  int64_t dimensionality = 1;
};

// Fill TvProperties structure about tv
TvProperties getProperties(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    TensorView* tv);

// Struct to store persistent buffer sizes. also holds the persistent buffer
// size of the buffers are projected to the inputs.
struct PersistentBufferSizeReturn {
  int64_t persistent_buffer_size = 0;
  int64_t projected_persistent_buffer_size = 0;
};

// Compute the amount of register space would be needed to perform this kernel
// persistently, only based on buffers that must be persistent, and based on the
// maximum of all minimum size requirement. i.e. if must be persistent, only
// hold persistent dimension.
TORCH_CUDA_CU_API PersistentBufferSizeReturn persistentBufferSize(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    PersistentBufferInfo& persistent_buffers,
    HeuristicSummary* data_cache = nullptr);

// Returns a set of all iteration domains (in roots of tensors) that map to a
// trivial reduction
std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion);

// Merges tensor view to the form:
// [IterationDomain, ReductionDomain, TrivialReductionDim0,
// TrivialReductionDim1, ...] Returns if <iteration dimensions, reduction
// dimensions>
std::pair<bool, bool> canonicalDimReduction(
    Fusion* fusion,
    TensorView* tv,
    bool schedule_3D = false);

// Return a list of tensor views that are outputs of reduction operations. If
// multiple outputs of an expression are found, only include one in the list
TORCH_CUDA_CU_API std::vector<TensorView*> getReductionTvs(
    Fusion* fusion,
    bool ignore_trivial = true);

// Returns a list of TensorViews that are the consumer tv for a view operation.
std::vector<TensorView*> getViewTVs(Fusion* fusion);

// Reset inputs and outputs to global memory, everything else to local.
void clearMemorySpace(Fusion* fusion);

// Returns cached after tensors of the fusion inputs if unrolled. Otherwise
// return empty vector.
TORCH_CUDA_CU_API std::vector<TensorView*> cacheInputs(
    Fusion* fusion,
    bool unroll);

// Returns the pairs of <cache of each fusion output, corresponding output> for
// all outputs.
TORCH_CUDA_CU_API std::vector<std::pair<TensorView*, TensorView*>>
cacheAndForkOutputs(Fusion* fusion, bool unroll);

// Ignores broadcast and reduction, returns iter domain in root domain that's
// "inner most". If this is an rfactored reduction domain, actually check the
// root domain, this is because the rfactored reduction tensorview has the
// vectorized dimension, but that means the rfactor domain could have reordered
// what we consider the "inner most" allocated position on it if we consider the
// rfactor dimension.
//
// If reduction tv and has rfactor return root domain, otherwise return rfactor
// domain.
IterDomain* innerMostRootDim(TensorView* tv);

// Looks through fusion and finds all dims that match to the one provided in
// the tensorview provided. Iter domain must be a root domain. If inner_only,
// will only map dimensions if they're the inner most position. This is
// important when projecting a dimension between an rfactor position and its
// root position when mapping from consumer to producer. If inner_only=true,
// takes the rfactor/root dimensions that maps, projects it to the root/rfactor
// domain, but only following the inner most pass when encounting split/merge.
// When propagating backward, for split it will only propagate backwards if the
// mapped dimension is the inner portion of the split. For merge, inner_only
// doesn't make a dimension and will propagate through the inner portion of the
// merge. When propagating forward, the logic is symmetric with the backward
// case.
class FindAllMappedDims : public MaxInfoSpanningTree::Propagator {
  std::unordered_map<TensorView*, IterDomain*> mapped_root_ids_;
  std::unordered_map<TensorView*, IterDomain*> mapped_rfactor_ids_;
  TensorView* starting_tv_ = nullptr;
  IterDomain* starting_id_ = nullptr;
  bool inner_only_;

 public:
  FindAllMappedDims(TensorView* from, IterDomain* starting_id, bool inner_only);
  virtual void setUp() override;
  virtual void propagateC2P(TensorView* from, TensorView* to) override;
  virtual void propagateP2C(TensorView* from, TensorView* to) override;
  virtual void propagateSibling(TensorView* from, TensorView* to) override;
  std::unordered_set<IterDomain*> get() const;
};

// Checks if tensor view has an iteration domain in vector dims in its inner
// most root position (excluding broadcast and reduction), and checks if it is a
// contiguous dimension
bool hasInnerDim(
    TensorView* tv,
    std::unordered_set<IterDomain*> vector_dims,
    bool should_vectorize);

// Returns all inputs and outputs that share the inner most dimension of the
// provided reference. If reference is an input it ignores reduction axes, will
// ignore all broadcast axes. If inner_only, will require inner->inner mapping
// in view, otherwise, it allows all inner->any mapping. If vectorize_pass, will
// check contiguity for vectorization, otherwise it just checks it has that
// inner dim.
std::vector<TensorView*> getInputsOutputsWithInnerDim(
    TensorView* reference_tv,
    bool inner_only,
    bool vectorize_pass);

// Structure to hold byte multiples for break points. I.e. if we have the
// tensors:
// T0[I0, I1] float
// T1[I0, I1] bool
// T2[I0]     half
// T3    [I1] double
// and a break point of 1 the multiples would be:
// lhs_multiple = 4 + 1 + 2 = 7
// rhs_multiple = 4 + 1 + 8 = 13
struct BroadcastMultiple {
  int64_t rhs_multiple = 0;
  int64_t lhs_multiple = 0;
};

// Returns a vector of counts, size = reference_tv->getRootDomain().size(), each
// entry [i] is the number of inputs/outputs that have a non-broadcast dimension
// mapped to the corresponding dimension in reference_tv. Count includes
// reference_tv if reference_tv is an input or output. Count is multiplied by
// data type size.
std::vector<BroadcastMultiple> getBroadcastMultiples(
    TensorView* reference_tv,
    DataType index_type);

//! Collect maximum vectorization word size of a tensor whose
//! innermost domain is leaf_merged_domain. Contig merging is taken
//! into account to expand vectorization if possible.
size_t collectMaxVectorizeSizeWithContigMerge(
    TensorView* tv,
    IterDomain* leaf_merged_domain,
    size_t max_word_size_in_byte,
    ExpressionEvaluator& expression_evaluator,
    DataType index_type);

namespace matmul_utils {
//! Utilities in this namespace facilitates scheduling matmul kernels with
//!  hierarchichal tiling specified in MatMulTileOptions.

//! Schedule utility for matmul prolog:
//!   Use all the threads on a CTA tile to load matmul operands
//!  into shared memory with the given vectorization word.
//! TODO:
//!  will need to add bank conflict removal swizzle in a follow up.
TORCH_CUDA_CU_API void scheduleContiguousVectorLoad(
    TensorView* tv,
    MatMulTileOptions tile,
    int vector_word,
    bool vectorize = true);

//! Schedule utility for mma output in matmul main loop:
//!  Realize the hierarchical tiling based on the given tiling options.
//! TODO: rewrite this one with makeTile
TORCH_CUDA_CU_API void scheduleWarpTileWithReduction(
    TensorView* tv,
    MatMulTileOptions tile);

//! Schedule utility for mma output in matmul main loop:
//!  Realize the hierarchical tiling based on the given tiling options
//! on consumers of mma ops in epilog.
//! TODO: remove this one eventually.
TORCH_CUDA_CU_API void scheduleWarpTileWithNoReduction(
    TensorView* tv,
    MatMulTileOptions tile);

//! Lower level primitive spliting inner iterdomains into tiles:
//! Eg.
//!  A[B,I0,I1,I2] -> makeTile({1,2,3})
//! Gives A[B, I0o, I1o, I2o, I0i(1), I1i(2), I2i(3)]
TORCH_CUDA_CU_API void makeTile(TensorView* tv, std::vector<int> tile_sizes);

//! Order the inner tile dimensions as the original order in
//!  root domain. Also putting broadcast domains on the left.
//! Eg. A[I0o,I1o,B2o,I0i,I1i,B2i] (root domain: I1,B,I0)
//! -> A[I0o, I1o, B2o, B2i, I1i, I0i]
//! This is used to facilitate data layout swizzling and
//!  defining vectorized loads.
TORCH_CUDA_CU_API void orderTiledConcreteIdAsRoot(TensorView* tv);

//! Orders the root id ordering of the given tv as
//! [Batch, Previous Reduction, M, N, K]
//!  for easier processing of later scheduling steps.
//!
//! This matching works on root domain only, and
//!  will throw if the tv has a leaf iterdomain that is
//!  not a root id.
TORCH_CUDA_CU_API void canonicalizeMmaTvOrdering(TensorView* tv);

} // namespace matmul_utils

//! Propagate current transformations on from_tv up to the given
//!  position, to all tensorviews on the owning fusion that has
//!  a connection with `from_tv` on the fusion graph.
TORCH_CUDA_CU_API void transformPropagateToAllFrom(
    TensorView* from_tv,
    int pos);

//! A type of custom transform propagator that propagates iterdomain
//!  transforms from a source tv to all tvs that are selected
//!  using a "direction" and a "boundary".
//!
//! The propagation model always assumes a `from_tv`, a `direction` and a
//! `boundary`.
//!
//! This propagator will only transform producers and consumers
//! of `from_tv`, and all propagation modes **require** a boundary to be
//! specified to signify where the propagation should stop.
//!
//! There are currently three modes of propagation: forward, backward and
//! both-way, see comment on the interface functions for details.
struct TORCH_CUDA_CU_API BoundedDirectionalTransformPropagator {
  //! Custom option container for configuring
  //!  the transform propagation actions.
  //! All option values default to false unless
  //!  the corresponding setter is called.
  struct Options {
    //! If true, the transform propagator will
    //!   also propagate parallel types from
    //!   `from_tv` to all selected tvs.
    bool propagate_parallel_type = false;

    //! If true, the specified boundary tvs
    //!  will also be replayed as `from_tv`.
    //!  If false, they will not be affected
    //!  by the propagation pass.
    bool transform_boundary = false;

    //! Sets the position boundary in parallel
    //!  type propagation, see comment on
    //!  scheduler_utils::parallelizeAllLike.
    //! Only used if propagate_parallel_type==true.
    int parallel_propagation_pos = -1;

    //! Setter for enabling parallel type
    //!  propagation. see comment on the variable.
    //!
    //! \param up_to_pos, sets the parallel type
    //!  propagation boundary. see comment on
    //!  scheduler_utils::parallelizeAllLike.
    Options propagateParallelType(int up_to_pos = -1) {
      propagate_parallel_type = true;
      parallel_propagation_pos = up_to_pos;
      return *this;
    }

    //! Setter for enabling propagation to
    //!  boundary tvs. see comment on the variable
    Options propagateToBoundary() {
      transform_boundary = true;
      return *this;
    }
  };

  //! Replay transforms from tensorview `from`
  //!  to the tensorviews that are consumers
  //!  of boundary tensorviews in `to` and producers of `from`.
  static void backward(
      TensorView* from,
      int pos,
      std::vector<TensorView*> to,
      c10::optional<Options> options = c10::nullopt);

  //! Replay transforms from tensorview `from`
  //! to the tensorviews that are producers
  //!  of boundary tensorviews in `to` and consumers of `from`.
  static void forward(
      TensorView* from,
      int pos,
      std::vector<TensorView*> to,
      c10::optional<Options> options = c10::nullopt);

  //! Replay transforms from tensorview `from`
  //!  to all the tensorviews that are consumers
  //!  of tensorviews in `backward_to` and producers
  //!  of tensorviews in `forward_to` while being
  //!  either a producer or a consumer of tensorview `from`.
  static void bothWays(
      TensorView* from,
      int pos,
      std::vector<TensorView*> backward_to,
      std::vector<TensorView*> forward_to,
      c10::optional<Options> options = c10::nullopt);

 private:
  //! Utility function:
  //!  Will realize the transform propagation to the
  //! tensorview's in `included_tvs`.
  //!  Assumes that all tvs in included_tvs are either
  //! a producer or a consumer of from_tv.
  static void propagate(
      TensorView* from_tv,
      int pos,
      std::unordered_set<TensorView*> included_tvs,
      Options options);
};

} // namespace scheduler_utils
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
